# +~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~ #  
#
#' @title  Sample tweets for round two of crowd-sourced elite criticism coding
#' @author Hauke Licht
#
# +~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~ #

# setup ----

# load packages
library(readr)
library(dplyr)
library(purrr)

# define dir paths (all relative to the working directory, i.e. the project root)
base_path      <- "."
data_path      <- file.path(base_path, "data")
fits_path      <- file.path(data_path, "fits")
labelings_path <- file.path(data_path, "intermediate", "labelings")
helpers_path   <- file.path(base_path, "code", "helpers")

# project helper functions (contents defined in code/helpers/utils.R)
source(file.path(helpers_path, "utils.R"))

# lookup table: annotation label (names) -> short label (values);
# the same vector is bound to both `label_map` and its alias `altlabels`
label_map <- setNames(
  c("general", "specific", "unsure", "no"),
  c("yes-general", "yes-specific", "yes-unsure", "no")
)
altlabels <- label_map

# predict labels/load labeled tweets ----
# Predictions are cached at `fp`: if the file exists it is read back;
# otherwise tweets + features are combined, labeled with the glmnet
# classifier trained on sample-1 labelings, and written to `fp`.
fp <- file.path(labelings_path, "glmnet_predicted_elitecriticism_labels_sample_1.rds")

if (file.exists(fp)) {
  dat <- read_rds(fp)
} else {

  # only needed when predictions must be computed; library() (unlike
  # require()) stops with a clear error if a package is not installed
  library(data.table, quietly = TRUE)
  library(caret, quietly = TRUE)

  # load (English) political tweets w/ clusters
  clustered_tweets <- read_rds(file.path(fits_path, "political_en_tweets_clustered.rds"))

  # add item ID: "<country_iso3c>_<user_id>_<status_id>"
  clustered_tweets$item_id <- with(
    clustered_tweets,
    sprintf("%s_%s_%s", country_iso3c, user_id, status_id)
  )

  # load tweet features
  # note: need to rely on old tweets (prior to update in May/June 2023) for reproducibility
  tfeats <- read_rds(file.path(data_path, "input", "parl_party_tweets_features_2020-04-24.rds"))
  # the `:=` updates below require a data.table; setDT() converts in place
  # (and is harmless if `tfeats` already is one)
  setDT(tfeats)

  key_cols <- c("country_iso3c", "party_id", "user_id", "status_id", "item_id")
  # loads `training.features` into the current environment; its `colname`
  # element lists the classifier's feature names
  load(file.path(helpers_path, "politicaltweets", "training.features.rda"))
  feature_vars <- training.features$colname

  # recode and add features
  tfeats[, item_id := sprintf("%s_%s_%s", country_iso3c, user_id, status_id)]
  tfeats[, tweet_type := factor(case_when(is_reply ~ "reply", is_quote ~ "quote", TRUE ~ "tweet"))]
  # NOTE(review): assumes `date` is a Date; if it is POSIXct, comparing it
  # with a Date is not meaningful -- confirm the column type
  tfeats[, nchar_limit := factor(
    date > lubridate::ymd("2017-11-07"),
    levels = c(FALSE, TRUE),
    labels = c("140", "280")
  )]

  # keep only key columns and the classifier features present in the data
  tfeats <- tfeats[, .SD, .SDcols = c(key_cols, feature_vars[feature_vars %in% names(tfeats)])]

  # combine with clustering information (joins on all shared columns)
  dat <- left_join(clustered_tweets, as_tibble(tfeats))

  rm(clustered_tweets, tfeats)
  gc()

  # load classifier trained on sample-1 labelings
  trained_glmnet <- read_rds(file.path(fits_path, "glmnet_classifier_sample_1_tweets.rds"))

  # predict label class probabilities (one column per class); the generic
  # predict() dispatches to caret's exported predict.train method
  tmp <- predict(trained_glmnet, newdata = dat, type = "prob")
  rm(trained_glmnet)

  # predict.train() drops rows with missing predictors by default
  # (na.action = na.omit); predictions must stay aligned with `dat`
  if (nrow(tmp) != nrow(dat)) {
    stop(
      "predicted ", nrow(tmp), " rows for ", nrow(dat),
      " tweets; some tweets have missing predictors",
      call. = FALSE
    )
  }

  # induce classification: name of the highest-probability class per tweet
  tmp$pred_label <- names(tmp)[apply(tmp, 1, which.max)]

  # add predictions to tweet features data
  dat <- bind_cols(dat, tmp)
  rm(tmp)
  gc()

  # make sure the cache directory exists before writing
  dir.create(dirname(fp), recursive = TRUE, showWarnings = FALSE)
  write_rds(dat, fp)
}

table(dat$pred_label)

# sample tweets from LASER embedding-based clusters ----

# candidate pool: tweets with a missing `sample` value (presumably not yet
# drawn in an earlier sample -- confirm) whose predicted label is not "no";
# keeps the first 12 columns, the cluster, and the class-probability columns
tmp <- dat %>%
  select(1:12, cluster_id, all_of(unique(dat$pred_label)), pred_label, sample) %>%
  filter(is.na(sample) & pred_label != "no")

# per cluster: number of distinct countries, distinct parties, and tweets
cluster_unit_sizes <- tmp %>%
  group_by(cluster_id) %>%
  summarise(
    n_countries = n_distinct(country_iso3c),
    n_parties = n_distinct(party_id),
    cluster_size = n(),
    .groups = "drop"
  )

# how many clusters span 1, 2, ... parties
table(cluster_unit_sizes[["n_parties"]])

# Number of tweets to sample for a cluster with `x` units (vectorized):
# log2(x) rounded up, plus one -- grows logarithmically with `x`
# (e.g. 1 -> 1, 2 -> 2, 4 -> 3, 8 -> 4).
determine_sample_size <- function(x) {
  1 + ceiling(log2(x))
}
# per-cluster sample sizes, named by cluster_id so they are matched to the
# per-cluster data by name (not by relying on group_by() and split() sorting
# cluster IDs identically)
s <- with(cluster_unit_sizes, setNames(determine_sample_size(n_parties), cluster_id))

fp <- file.path(data_path, "intermediate", "samples", "tweets_sample_2.rds")
if (!file.exists(fp)) {
  # fix the RNG state so that the sample can be regenerated
  set.seed(1234)

  sampled_tweets <- tmp %>%
    select(1:12, pred_label, cluster_id) %>%
    split(.$cluster_id) %>%
    imap(function(cluster_tweets, cluster) {
      size <- s[[cluster]]
      # at most one (random) tweet per user within the cluster
      one_per_user <- cluster_tweets %>%
        group_by(user_id) %>%
        slice_sample(n = 1) %>%
        ungroup()
      # never request more tweets than there are distinct users;
      # `cluster_sample_size` records the target size
      one_per_user %>%
        slice_sample(n = min(size, nrow(one_per_user))) %>%
        mutate(cluster_sample_size = size)
    }) %>%
    bind_rows() %>%
    # shuffle rows so tweets from the same cluster are not adjacent
    slice_sample(prop = 1)

  dir.create(dirname(fp), recursive = TRUE, showWarnings = FALSE)
  write_rds(sampled_tweets, fp)
}

